{
 "cells": [
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import torch\n",
    "from torch import nn"
   ],
   "id": "49bfb4fbf30bb0e2"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Seed the global RNG so Restart & Run All reproduces the same input and layer weights.\n",
    "torch.manual_seed(0)\n",
    "\n",
    "N_FEATURES = 128 * 128  # one flattened 128x128 input\n",
    "N_OUTPUTS = 2  # number of output scores\n",
    "\n",
    "# Random pixel-like input with shape (1, N_FEATURES).\n",
    "# randint's upper bound is exclusive, so high=256 yields values in [0, 255].\n",
    "x = torch.randint(0, 256, (1, N_FEATURES), dtype=torch.float32)\n",
    "\n",
    "# Fully connected (linear) layer: N_FEATURES input features -> N_OUTPUTS output scores.\n",
    "fc = nn.Linear(N_FEATURES, N_OUTPUTS)\n",
    "y = fc(x)  # unnormalized scores, shape (1, N_OUTPUTS)\n",
    "print(y)\n",
    "\n",
    "# y has shape (1, 2); softmax over dim=1 normalizes each row's scores into probabilities.\n",
    "output = nn.Softmax(dim=1)(y)\n",
    "print(output)"
   ],
   "id": "61d48ff588ded8e"
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}
